-- Copy routines to stream multiple copies in parallel -*- lua -*-
--
-- For big hash tables, it is expected that we will have to go all the
-- way out to main memory every time you go to look up a value.  That's
-- pretty expensive: around 70 nanoseconds per cache miss.  We can reduce
-- this cost by making many fetches at once, and relying on the processor
-- to parallelize the requests.  In this way although we can expect
-- the latency for one lookup operation to be approximately the cost
-- of one cache miss, if the lookup returns N values the per-lookup
-- cost is divided by N.
--
-- See http://www.realworldtech.com/haswell-cpu/5/ for more on how the
-- memory subsystem works in a modern processor.

-- Lua 5.1-style module declaration.  The module name comes from the
-- chunk's arguments; package.seeall keeps the usual globals visible.
-- Global functions defined below (gen, selftest) are the module's
-- exported interface.
module(..., package.seeall)

-- When true, assemble() prints a dump of every routine it generates.
-- Note that this local shadows the standard `debug` library for the
-- rest of the file.
local debug = false

local ffi = require("ffi")
local C = ffi.C

local dasm = require("dasm")

-- DynASM preprocessor directives: emit x64 code, and name the
-- resulting action list `actions` (handed to dasm.new in assemble()).
|.arch x64
|.actionlist actions

-- Table keeping machine code alive to the GC.  gen() also stores its
-- tail-mask buffers here.
local anchor = {}

-- Utility: assemble code and optionally dump disassembly.
-- Utility: run GENERATOR against a fresh DynASM state, build the
-- resulting machine code, and return it cast to the C type PROTOTYPE.
-- NAME is only used to label the dump printed when `debug` is set.
local function assemble (name, prototype, generator)
   local state = dasm.new(actions)
   generator(state)

   local code, nbytes = state:build()
   -- Hold a reference in `anchor` so the code stays alive to the GC.
   anchor[#anchor + 1] = code

   if debug then
      print("mcode dump: "..name)
      dasm.dump(code, nbytes)
   end

   return ffi.cast(prototype, code)
end

-- Generate a specialised copy routine and return it as a C function
-- pointer of type `void(*)(void *dst, void **src)`.
--
--   count: number of source regions copied per call.
--   size:  number of bytes copied from each region.  The part of SIZE
--          beyond a multiple of 32 must be a multiple of 4, because
--          the tail is copied with a mask of 4-byte elements.
--
-- The generated routine reads COUNT pointers from the SRC array and
-- copies SIZE bytes from each of them into DST, packing the copies
-- back to back (copy number i lands at dst + i*size).  Regions are
-- processed in groups of up to 8 so that the loads for a whole group
-- are issued before any of the stores.
function gen(count, size)
   local function gen_multi_copy(Dst)
      -- dst in rdi
      -- src in rsi

      | vzeroall
      -- Source pointers are kept in r8-r15.  r12-r15 are callee-saved
      -- in the System V ABI, so preserve them around the routine.
      | push r12
      | push r13
      | push r14
      | push r15

      -- Bytes left over after the last full 32-byte chunk are copied
      -- with a masked move; ymm15 holds the mask for the whole routine.
      local tail_size = size % 32
      local tail_mask
      if tail_size ~= 0 then
         assert(tail_size % 4 == 0, '4-byte alignment required')
         tail_mask = ffi.new("uint8_t[32]")
         for i=0,tail_size-1 do tail_mask[i]=255 end
         -- The generated code embeds the mask's address, so the mask
         -- has to stay alive as long as the code does.
         table.insert(anchor, tail_mask)
         -- Load the full 64-bit address.  A plain `mov` into a 64-bit
         -- register only encodes a sign-extended 32-bit immediate,
         -- which would truncate addresses outside that range.
         | mov64 rax, tail_mask
         | vmovdqu ymm15, [rax]
      end

      -- Stream in data from up to 8 regions at once.  A local counter
      -- is used so that the COUNT parameter itself is left untouched.
      local remaining = count
      while remaining > 0 do
         local stride = math.min(remaining, 8)
         local to_copy = size
         -- Load this group's source pointers into r8, r9, ...
         for i = 0, stride-1 do
            | mov Rq(8+i), [rsi + 8*i]
         end
         while to_copy >= 32 do
            -- Copying 64 bytes per region per pass needs ymm8-ymm15 as
            -- well, so it is only possible when ymm15 is not the mask.
            local double_copy = to_copy >= 64 and not tail_mask
            local inc = double_copy and 64 or 32
            for i = 0, stride-1 do
               | vmovdqu ymm(i), [Rq(8+i)]
               | add Rq(8+i), 32
               if double_copy then
                  | vmovdqu ymm(8+i), [Rq(8+i)]
                  | add Rq(8+i), 32
               end
            end
            for i = 0, stride-1 do
               | vmovdqu [rdi + i*size], ymm(i)
               if double_copy then
                  | vmovdqu [rdi + i*size+32], ymm(8+i)
               end
            end
            | add rdi, inc
            to_copy = to_copy - inc
         end

         -- Copy the sub-32-byte tail, touching only the masked bytes.
         if to_copy > 0 then
            for i = 0, stride-1 do
               | vmaskmovps ymm(i), ymm15, [Rq(8+i)]
            end
            for i = 0, stride-1 do
               | vmaskmovps [rdi + i*size], ymm15, ymm(i)
            end
            | add rdi, to_copy
            to_copy = 0
         end

         -- Now the dst has been advanced by SIZE.  Increment for the
         -- parallel strides.
         | add rdi, (stride-1)*size
         -- Increment the src as well.
         | add rsi, stride*8
         remaining = remaining - stride
      end
      | vzeroall
      | pop r15
      | pop r14
      | pop r13
      | pop r12
      | ret
   end

   return assemble("multi_copy_"..size,
                   "void(*)(void*, void*)",
                   gen_multi_copy)
end

-- Self-test: for a range of sizes and counts, generate a copy routine
-- with gen(), run it over overlapping windows of a known byte pattern,
-- and check both the copied bytes and that nothing past them in the
-- destination was written.  Skipped when /proc/cpuinfo lacks "avx2".
function selftest ()
   print("selftest: multi_copy")

   local cpuinfo = require('core.lib').readfile("/proc/cpuinfo", "*a")
   assert(cpuinfo, "failed to read /proc/cpuinfo for hardware check")
   if cpuinfo:match("avx2") == nil then
      print("selftest: not supported; avx2 unavailable")
      return
   end

   -- Source pattern: one 1, two 2s, three 3s, ... twelve 12s, which
   -- makes 1 + 2 + ... + 12 = 78 bytes.
   local src_len = 78
   local pattern = {}
   for value = 1, 12 do
      for _ = 1, value do pattern[#pattern + 1] = value end
   end
   local src = ffi.new('uint8_t['..src_len..']', pattern)

   -- The destination gets 100 bytes per copy, more than the largest
   -- SIZE tested, so there is always slack to check for overruns.
   local dst_slot = 100

   for size = 4, 76, 4 do
      for count = 1, 10 do
         local dst_len = dst_slot * count
         local dst = ffi.new('uint8_t['..dst_len..']')
         local srcv = ffi.new('void*['..count..']')
         local multi_copy = gen(count, size)
         local last_offset = (src_len - size - count) - 1

         for offset = 0, last_offset do
            ffi.C.memset(dst, 0, dst_len)
            -- Source i starts one byte after source i-1.
            for i = 0, count-1 do srcv[i] = src + offset + i end

            multi_copy(dst, srcv)

            -- Each copy must match its source window byte for byte.
            for i = 0, count-1 do
               local base = i * size
               for j = 0, size-1 do
                  assert(dst[base + j] == src[offset + i + j])
               end
            end
            -- Everything after the packed copies must still be zero.
            for k = count*size, dst_len-1 do assert(dst[k] == 0) end
         end
      end
   end

   print("selftest: ok")
end
